"""Beaker metrics directly from the database."""
from collections import defaultdict
from collections import namedtuple
import csv
import datetime
import functools
import io
import itertools
import os
import re

import boto3
from cki_lib import mariadb
from cki_lib import misc
from cki_lib import s3bucket
from cki_lib import yaml
from cki_lib.cronjob import CronJob
from cki_lib.logger import get_logger
import prometheus_client

# Module-level logger named after this module.
LOGGER = get_logger(__name__)

# Deployment configuration parsed from the BEAKER_CONFIG environment variable
# (YAML). NOTE(review): with the variable unset this loads an empty string —
# presumably yielding None, so the .get() calls below assume it is set; confirm
# against the cki_lib yaml wrapper.
BEAKER_CONFIG = yaml.load(contents=os.environ.get('BEAKER_CONFIG', ''))
# Beaker recipe statuses that mean a recipe has not finished yet; used both to
# filter queries and to zero out metric label combinations with no rows.
NON_FINISHED_STATUSES = [
    'Installing', 'New', 'Processed', 'Queued', 'Running', 'Scheduled', 'Waiting',
]
# All recipe-set priorities; combined with NON_FINISHED_STATUSES to enumerate
# every possible (status, priority) label pair.
ALL_PRIORITIES = {'Low', 'Medium', 'Normal', 'High', 'Urgent'}
# Beaker user accounts whose jobs/recipes are included in the metrics.
USER_NAMES = ['beaker/cki-team-automation']

# One row of the per-task CSV report uploaded to S3 (see update_task_report).
TaskData = namedtuple('TaskData',
                      'queue_time owner_name job_id recipe_set_id recipe_id task_id task_name')
# One row of per-task usage counts exposed via the cki_beaker_tasks gauge.
TaskStatistic = namedtuple('TaskStatistic', 'id name count')


class BeakerDBMetrics(CronJob):
    """Calculate Beaker metrics directly from the database."""

    # Run every five minutes.
    schedule = '*/5 * * * *'

    # NOTE(review): the help text says "jobs" but the metric is filled with
    # per-queue counts of queued recipes; kept verbatim to preserve behavior.
    metric_recipes_in_queue = prometheus_client.Gauge(
        'cki_beaker_recipes_in_queue',
        'Number of jobs in a Beaker queue',
        labelnames=['queue'],
    )

    metric_recipes_by_status = prometheus_client.Gauge(
        'cki_beaker_recipes_by_status',
        'Number of non-finished recipes per status',
        labelnames=['status', 'priority'],
    )

    metric_jobs = prometheus_client.Gauge(
        'cki_beaker_job_count',
        'Number of jobs',
        labelnames=['user'],
    )

    # Bucket boundaries: 0.5h, 1h, 3h, 6h, 12h, 24h, 48h (in seconds).
    metric_recipes_queue_time = prometheus_client.Histogram(
        'cki_beaker_recipes_queue_time',
        'Queue time for all currently queued recipes in seconds',
        buckets=[3600 * h for h in (0.5, 1, 3, 6, 12, 24, 48)],
    )

    metric_tasks = prometheus_client.Gauge(
        'cki_beaker_tasks',
        'Total number of instantiated tasks from the task library',
        labelnames=['id', 'name'],
    )

    @staticmethod
    def _upload_s3(key: str, raw: bytes) -> None:
        """Upload raw bytes to the configured S3 bucket under the given key."""
        bucket_spec = s3bucket.parse_bucket_spec(os.environ[BEAKER_CONFIG['bucket']])
        LOGGER.debug('Uploading to %s/%s', bucket_spec.bucket, bucket_spec.prefix)
        s3_client = boto3.Session().client(
            's3',
            aws_access_key_id=bucket_spec.access_key or None,
            aws_secret_access_key=bucket_spec.secret_key or None,
            endpoint_url=bucket_spec.endpoint or None,
        )
        s3_client.put_object(
            Bucket=bucket_spec.bucket,
            Key=f'{bucket_spec.prefix}{key}',
            Body=raw,
        )

    @functools.cached_property
    def db_handler(self):
        """Get DB connection. Not done in init for simpler testing."""
        return mariadb.MariaDBHandler()

    def get_queued_recipes_constraints(self):
        """Get distro_requires for the queued recipes."""
        query = f"""
            SELECT
                r._distro_requires
            FROM
                beaker.recipe AS r
            JOIN
                beaker.recipe_set AS rs ON rs.id = r.recipe_set_id
            JOIN
                beaker.job AS j ON j.id = rs.job_id
            JOIN
                beaker.tg_user AS u ON u.user_id = j.owner_id
            WHERE
                r.status = 'Queued'
            AND
                u.user_name IN {self.db_handler.tuple_placeholder(USER_NAMES)};
        """
        rows = self.db_handler.execute(query, USER_NAMES)
        # Each row holds a single column: the raw _distro_requires XML string.
        return [row[0] for row in rows]

    def get_recipes_by_status(self):
        """Get recipes grouped by status."""
        query = f"""
            SELECT
                r.status,
                rs.priority,
                count(r.status)
            FROM
                beaker.recipe AS r
            JOIN
                beaker.recipe_set AS rs ON rs.id = r.recipe_set_id
            JOIN
                beaker.job AS j ON j.id = rs.job_id
            JOIN
                beaker.tg_user AS u ON u.user_id = j.owner_id
            WHERE
                u.user_name IN {self.db_handler.tuple_placeholder(USER_NAMES)}
            AND
                r.status IN {self.db_handler.tuple_placeholder(NON_FINISHED_STATUSES)}
            GROUP BY
                r.status, rs.priority;
        """
        rows = self.db_handler.execute(query, USER_NAMES + NON_FINISHED_STATUSES)
        return [(status, priority, int(count)) for status, priority, count in rows]

    @staticmethod
    def _recipe_matches_queue(distro_requires, architectures):
        """Check whether distro_requires pins exactly the given architectures."""
        found_architectures = re.findall(
            r'distro_arch op="=" value="(\S+)"', distro_requires)
        return set(architectures) == set(found_architectures)

    def update_cki_beaker_recipes_in_queue(self):
        """Update cki_beaker_recipes_in_queue metric."""
        queued_recipes = self.get_queued_recipes_constraints()

        for queue_name, architectures in BEAKER_CONFIG.get('queues', {}).items():
            # A queue configured without architectures matches its own name.
            wanted = architectures or [queue_name]
            count = sum(
                1 for distro_requires in queued_recipes
                if self._recipe_matches_queue(distro_requires, wanted)
            )
            self.metric_recipes_in_queue.labels(queue_name).set(count)

    def update_cki_beaker_recipes_by_status(self):
        """Update cki_beaker_recipes_by_status metric."""
        seen_labels = set()
        for status_name, priority_name, count in self.get_recipes_by_status():
            self.metric_recipes_by_status.labels(status_name, priority_name).set(count)
            seen_labels.add((status_name, priority_name))

        # Zero out every (status, priority) pair the query did not return.
        possible_labels = set(itertools.product(NON_FINISHED_STATUSES, ALL_PRIORITIES))
        # sorted() for reproducible testing.
        for status_name, priority_name in sorted(possible_labels - seen_labels):
            self.metric_recipes_by_status.labels(status_name, priority_name).set(0)

    def get_jobs(self):
        """Get jobs."""
        query = f"""
            SELECT
                u.user_name,
                count(u.user_name)
            FROM
                beaker.job AS j
            JOIN
                beaker.tg_user AS u ON u.user_id = j.owner_id
            WHERE
                u.user_name IN {self.db_handler.tuple_placeholder(USER_NAMES)}
            GROUP BY
                u.user_name;
        """
        return [(user_name, int(count))
                for user_name, count in self.db_handler.execute(query, USER_NAMES)]

    def update_cki_beaker_jobs(self):
        """Update cki_beaker_jobs metric."""
        reported_users = set()
        for user_name, count in self.get_jobs():
            self.metric_jobs.labels(user_name).set(count)
            reported_users.add(user_name)

        # Zero out configured users with no jobs in the database.
        # sorted() for reproducible testing.
        for user_name in sorted(set(USER_NAMES) - reported_users):
            self.metric_jobs.labels(user_name).set(0)

    def get_recipes_queued_times(self):
        """Get recipes queued time."""
        query = f"""
            SELECT
               TIMESTAMPDIFF(SECOND, rs.queue_time, NOW())
            FROM
                beaker.recipe AS r
            JOIN
                beaker.recipe_set AS rs ON rs.id = r.recipe_set_id
            JOIN
                beaker.job AS j ON j.id = rs.job_id
            JOIN
                beaker.tg_user AS u ON u.user_id = j.owner_id
            WHERE
                u.user_name IN {self.db_handler.tuple_placeholder(USER_NAMES)}
            AND
                r.status = 'Queued';
        """
        return [datetime.timedelta(seconds=seconds)
                for (seconds,) in self.db_handler.execute(query, USER_NAMES)]

    def update_cki_beaker_recipes_queue_time(self):
        """Update cki_beaker_recipes_queue_time metric."""
        for queued_for in self.get_recipes_queued_times():
            self.metric_recipes_queue_time.observe(queued_for.total_seconds())

    def get_task_data(self, tasks: list[str], task_age: datetime.timedelta):
        """Get task data."""
        if not tasks:
            return []
        query = f"""
            SELECT
                rs.queue_time,
                u.user_name,
                j.id,
                rs.id,
                r.id,
                t.id,
                t.name
            FROM
                beaker.task AS t
            LEFT JOIN
                beaker.recipe_task AS rt ON rt.task_id = t.id
            LEFT JOIN
                beaker.recipe AS r ON r.id = rt.recipe_id
            LEFT JOIN
                beaker.recipe_set AS rs ON rs.id = r.recipe_set_id
            LEFT JOIN
                beaker.job AS j ON j.id = rs.job_id
            LEFT JOIN
                beaker.tg_user AS u ON u.user_id = j.owner_id
            WHERE
                t.name IN {self.db_handler.tuple_placeholder(tasks)}
            AND
                rs.queue_time > %s;
        """
        # NOTE(review): date arithmetic drops any sub-day component of
        # task_age before it reaches the query — confirm that is intended.
        cutoff = datetime.date.today() - task_age
        return [TaskData(*row)._asdict()
                for row in self.db_handler.execute(query, [*tasks, cutoff])]

    def get_task_statistics(self, tasks: list[str]):
        """Get task statistics."""
        if not tasks:
            return []
        query = f"""
            SELECT
                t.id,
                t.name,
                COUNT(*)
            FROM
                beaker.task AS t
            LEFT JOIN
                beaker.recipe_task AS rt ON rt.task_id = t.id
            WHERE
                t.name IN {self.db_handler.tuple_placeholder(tasks)}
            GROUP BY
                t.id, t.name;
        """
        return [TaskStatistic(*row) for row in self.db_handler.execute(query, tasks)]

    def update_task_report(self):
        """Update the task report in CSV format on S3."""
        task_age = misc.parse_timedelta(BEAKER_CONFIG.get('task_report_age', '7d'))
        task_rows = self.get_task_data(BEAKER_CONFIG.get('tasks'), task_age)
        buffer = io.StringIO(newline='')
        writer = csv.DictWriter(buffer, fieldnames=TaskData._fields)
        writer.writeheader()
        writer.writerows(task_rows)
        self._upload_s3('tasks.csv', buffer.getvalue().encode('utf8'))

    def update_task_metrics(self):
        """Update the Prometheus task metrics."""
        for task_statistic in self.get_task_statistics(BEAKER_CONFIG.get('tasks')):
            self.metric_tasks.labels(
                str(task_statistic.id), task_statistic.name).set(task_statistic.count)

    def run(self, **_):
        """Update all metrics."""
        self.update_cki_beaker_recipes_in_queue()
        self.update_cki_beaker_recipes_by_status()
        self.update_cki_beaker_recipes_queue_time()
        self.update_cki_beaker_jobs()
        self.update_task_report()
        self.update_task_metrics()
